{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f802e64d-4eaa-445d-a48a-1042a91bc394",
   "metadata": {
    "tags": []
   },
   "source": [
    "# GBI\n",
    "\n",
    "## 目标\n",
    "通过 GBI sdk 接口完成选表和问表的能力。 \n",
    "\n",
    "## 准备工作\n",
    "### 平台注册\n",
    "1.先在appbuilder平台注册，获取token\n",
    "2.安装appbuilder-sdk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2939356f-61c2-42e9-9e0c-fc6729c193f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install appbuilder-sdk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "4ccff03b-1567-4e8b-8e1f-9a5032690406",
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "import os\n",
    "\n",
    "#  设置环境变量\n",
    "os.environ[\"APPBUILDER_TOKEN\"] = \"***\"\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aeb2fa55-075f-48df-a9fb-8b40d9900684",
   "metadata": {},
   "source": [
    "## 开发过程"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c3c5cee",
   "metadata": {},
   "source": [
    "### 设置表的 schema"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d7d6440c",
   "metadata": {},
   "outputs": [],
   "source": [
    "SUPER_MARKET_SCHEMA = \"\"\"\n",
    "```\n",
    "CREATE TABLE `supper_market_info` (\n",
    "  `订单编号` varchar(32) DEFAULT NULL,\n",
    "  `订单日期` date DEFAULT NULL,\n",
    "  `邮寄方式` varchar(32) DEFAULT NULL,\n",
    "  `地区` varchar(32) DEFAULT NULL,\n",
    "  `省份` varchar(32) DEFAULT NULL,\n",
    "  `客户类型` varchar(32) DEFAULT NULL,\n",
    "  `客户名称` varchar(32) DEFAULT NULL,\n",
    "  `商品类别` varchar(32) DEFAULT NULL,\n",
    "  `制造商` varchar(32) DEFAULT NULL,\n",
    "  `商品名称` varchar(32) DEFAULT NULL,\n",
    "  `数量` int(11) DEFAULT NULL,\n",
    "  `销售额` int(11) DEFAULT NULL,\n",
    "  `利润` int(11) DEFAULT NULL\n",
    ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4\n",
    "```\n",
    "\"\"\"\n",
    "\n",
    "PRODUCT_SALES_INFO = \"\"\"\n",
    "现有 mysql 表 product_sales_info, \n",
    "该表的用途是: 产品收入表\n",
    "```\n",
    "CREATE TABLE `product_sales_info` (\n",
    "  `年` int,\n",
    "  `月` int,\n",
    "  `产品名称` varchar,\n",
    "  `收入` decimal,\n",
    "  `非交付成本` decimal,\n",
    "  `含交付毛利` decimal\n",
    ")\n",
    "```\n",
    "\"\"\"\n",
    "\n",
    "# schema 和表名的映射\n",
    "SCHEMA_MAPPING = {\n",
    "    \"supper_market_info\": SUPER_MARKET_SCHEMA,\n",
    "    \"PRODUCT_SALES_INFO\": PRODUCT_SALES_INFO\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "463254a1",
   "metadata": {},
   "source": [
    "设置表的描述用于选表"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7fefcae1",
   "metadata": {},
   "outputs": [],
   "source": [
    "table_descriptions = {\n",
    "    \"supper_market_info\": \"超市营收明细表，包含超市各种信息等\",\n",
    "    \"product_sales_info\": \"产品销售表\"\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0aff843",
   "metadata": {},
   "source": [
    "### 选表"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "41559341-fd7a-478c-a08b-1477d79e9d41",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-12-18T06:24:26.982459Z",
     "start_time": "2023-12-18T06:23:53.771345Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "选的表是: ['supper_market_info']\n"
     ]
    }
   ],
   "source": [
    "import appbuilder\n",
    "from appbuilder.core.message import Message\n",
    "from appbuilder.core.components.gbi.basic import SessionRecord\n",
    "\n",
    "# 生成问表对象\n",
    "select_table = appbuilder.SelectTable(model_name=\"ERNIE-Bot 4.0\", table_descriptions=table_descriptions)\n",
    "query = \"列出超市中的所有数据\"\n",
    "msg = Message({\"query\": query})\n",
    "select_table_result_message = select_table(msg)\n",
    "print(f\"选的表是: {select_table_result_message.content}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16a8aa38-7a33-4e27-bca4-00900cfe1641",
   "metadata": {},
   "source": [
    "### 问表\n",
    "基于上面选出的表，通过获取表的 schema 进行问表"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9f45ef5f-6206-4b31-83c4-3c8eb2c86925",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sql: SELECT * FROM supper_market_info;\n",
      "-----------------\n",
      "llm result: ```sql\n",
      "SELECT * FROM supper_market_info;\n",
      "```\n"
     ]
    }
   ],
   "source": [
    "table_schemas = [SCHEMA_MAPPING[table_name] for table_name in select_table_result_message.content]\n",
    "gbi_nl2sql = appbuilder.NL2Sql(model_name=\"ERNIE-Bot 4.0\", table_schemas=table_schemas)\n",
    "nl2sql_result_message = gbi_nl2sql(Message({\"query\": \"列出超市中的所有数据\"}))\n",
    "print(f\"sql: {nl2sql_result_message.content.sql}\")\n",
    "print(\"-----------------\")\n",
    "print(f\"llm result: {nl2sql_result_message.content.llm_result}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0409c46-e8c7-403a-a827-fcdc8e717be6",
   "metadata": {},
   "source": [
    "设置 session"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a23b8cad-f426-4074-9311-c2c33aaea07b",
   "metadata": {},
   "outputs": [],
   "source": [
    "session = list()\n",
    "session.append(SessionRecord(query=query, answer=nl2sql_result_message.content))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22b3d877-f61f-4958-a084-7507a3017e17",
   "metadata": {},
   "source": [
    "再次问表"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2adcb091-fb53-4364-b4d8-20564439ff51",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sql: SELECT * FROM supper_market_info WHERE 商品类别 = '水果'\n",
      "-----------------\n",
      "llm result: ```sql\n",
      "SELECT * FROM supper_market_info WHERE 商品类别 = '水果'\n",
      "```\n"
     ]
    }
   ],
   "source": [
    "nl2sql_result_message2 = gbi_nl2sql(Message({\"query\": \"查看商品类别是水果的所有数据\", \n",
    "                                             \"session\": session}))\n",
    "print(f\"sql: {nl2sql_result_message2.content.sql}\")\n",
    "print(\"-----------------\")\n",
    "print(f\"llm result: {nl2sql_result_message2.content.llm_result}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e0609ae-f2bc-43d3-9023-14e9f8618158",
   "metadata": {},
   "source": [
    "### 增加列选优化\n",
    "实际上数据中 \"商品类别\" 存储的是 \"新鲜水果\", 那么就可以通过列选的限制来优化 sql."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "2a7c7923-019e-4660-9e36-4431e9d2f3a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sql: SELECT * FROM supper_market_info WHERE 商品类别 = '新鲜水果'\n",
      "-----------------\n",
      "llm result: ```sql\n",
      "SELECT * FROM supper_market_info WHERE 商品类别 = '新鲜水果'\n",
      "```\n"
     ]
    }
   ],
   "source": [
    "from appbuilder.core.components.gbi.basic import ColumnItem\n",
    "\n",
    "\n",
    "column_constraint = [ColumnItem(ori_value=\"水果\", \n",
    "                               column_name=\"商品类别\", \n",
    "                               column_value=\"新鲜水果\", \n",
    "                               table_name=\"supper_market_info\", \n",
    "                               is_like=False)]\n",
    "\n",
    "nl2sql_result_message2 = gbi_nl2sql(Message({\"query\": \"查看商品类别是水果的所有数据\",\n",
    "                                    \"column_constraint\": column_constraint}))\n",
    "\n",
    "print(f\"sql: {nl2sql_result_message2.content.sql}\")\n",
    "print(\"-----------------\")\n",
    "print(f\"llm result: {nl2sql_result_message2.content.llm_result}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8385312c-aea1-42cd-b61b-a8d36f4f0665",
   "metadata": {},
   "source": [
    "从上面我们看到，商品类别不在是 \"水果\" 而是 修订为 \"新鲜水果\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e98c414-8b2b-4187-a270-3117a4f431ff",
   "metadata": {},
   "source": [
    "### 增加知识优化\n",
    "当计算某些特殊知识的时候，大模型是不知道的，所以需要告诉大模型具体的知识，比如:\n",
    "利润率的计算方式: 利润/销售额\n",
    "可以将该知识注入。具体示例如下:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "cade4693-29dc-431c-bf84-c6dc09104294",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 注入知识\n",
    "gbi_nl2sql.knowledge[\"利润率\"] = \"计算方式: 利润/销售额\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "1dc181e8-47a1-4b82-8bb5-ce3339be53f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sql: SELECT 商品类别, SUM(利润)/SUM(销售额) AS 利润率\n",
      "FROM supper_market_info\n",
      "WHERE 商品类别 = '日用品'\n",
      "GROUP BY 商品类别\n",
      "-----------------\n",
      "llm result: ```sql\n",
      "SELECT 商品类别, SUM(利润)/SUM(销售额) AS 利润率\n",
      "FROM supper_market_info\n",
      "WHERE 商品类别 = '日用品'\n",
      "GROUP BY 商品类别\n",
      "```\n"
     ]
    }
   ],
   "source": [
    "nl2sql_result_message3 = gbi_nl2sql(Message({\"query\": \"列出商品类别是日用品的利润率\"}))\n",
    "print(f\"sql: {nl2sql_result_message3.content.sql}\")\n",
    "print(\"-----------------\")\n",
    "print(f\"llm result: {nl2sql_result_message3.content.llm_result}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5570cd9-dbaf-45cd-ab03-1a7f92e7d0d4",
   "metadata": {},
   "source": [
    "## 调整 prompt 模版\n",
    "有时候，我们希望定义自己的prompt, 选表和问表两个环节都支持 prompt 模版的定制化，但是必须遵循对应的 prompt 模版的格式。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e3d4967-2b4c-437d-9d72-fb1b94bdcf59",
   "metadata": {},
   "source": [
    "### 选表 prompt 调整\n",
    "选表的 prompt template, 必须包含 \n",
    "1. {num} - 表的数量， 注意 {num} 有两个地方出现\n",
    "2. {table_desc} - 表的描述\n",
    "3. {query} - query\n",
    "\n",
    "注意: {num}, {table_desc}, {query} 表示的是占位符，**用户不需要在自定义的 prompt template 中将这些值填充上**，gbi 系统会自动根据 SelectTable 构造函数中提交的参数进行填充这些占位符，从而产生最后的给大模型的 prompt。注意 prompt template 和 prompt 的区别。\n",
    "\n",
    "* prompt template - 是带有占位符的 prompt, gbi 会根据具体参数填充到这些占位符，形成最终的 prompt\n",
    "* promt - 是将 prompt template 填充完占位符的结果。\n",
    "\n",
    "用户可以使用这些占位符重新设置自己的 prompt 模版，从而达到修改 prompt 的目的。\n",
    "具体请参考下面的 prompt template 示例:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "2ae6ffbc-4237-4fb2-8168-480b81bfd873",
   "metadata": {},
   "outputs": [],
   "source": [
    "SELECT_TABLE_PROMPT_TEMPLATE = \"\"\"\n",
    "你是一个专业的业务人员，下面有{num}张表，具体表名如下:\n",
    "{table_desc}\n",
    "请根据问题帮我选择上述1-{num}种的其中相关表并返回，可以为多表，也可以为单表,\n",
    "返回多张表请用“,”隔开\n",
    "返回格式请参考如下示例：\n",
    "问题:有多少个审核通过的投运单？\n",
    "回答: ```DWD_MAT_OPERATION```\n",
    "请严格参考示例只不要返回无关内容，直接给出最终答案后面的内容，分析步骤不要输出\n",
    "问题:{query}\n",
    "回答:\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "2bbbb375-6659-4ef0-82ff-a4ace9fdd4f0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "选的表是: ['supper_market_info']\n"
     ]
    }
   ],
   "source": [
    "select_table4 = appbuilder.SelectTable(model_name=\"ERNIE-Bot 4.0\", \n",
    "                                          table_descriptions=table_descriptions,\n",
    "                                          prompt_template=SELECT_TABLE_PROMPT_TEMPLATE)\n",
    "\n",
    "select_table_result_message4 = select_table4(Message({\"query\":\"列出超市中的所有数据\"}))\n",
    "print(f\"选的表是: {select_table_result_message4.content}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f3fd089-613b-4bdd-95ac-c87f89c0fc61",
   "metadata": {},
   "source": [
    "## 问表 prompt 调整\n",
    "问表的 prompt template 必须包含:\n",
    "1. {schema} - 表的 schema 信息, gbi系统使用构建 NL2Sql 对象的 table_schemas 成员来填充该占位符，在构造函数中需要填充该参数。\n",
    "2. {instrument} - 列选限制的信息, gbi系统会使用 `NL2Sql.__call__` 函数的 Message 中的 column_constraint 参数来填充该占位符\n",
    "3. {kg} - 知识, gbi系统使用构建 NL2Sql 对象的 knowledge  成员来填充该占位符，在构造函数中需要填充该参数。\n",
    "4. {date} - 时间， gbi系统会自动填充该占位符, 用户不需要提供\n",
    "5. {history_prompt} - 历史, gbi系统会使用 `NL2Sql.__call__` 函数的  Message 中的 session 参数来填充该占位符\n",
    "6. {query} - 当前问题，  gbi系统会使用 `NL2Sql.__call__` 函数的 Message 中的 query 参数来填充该占位符\n",
    "\n",
    "\n",
    "注意: {schema}, {instrument}, {kg}, {date}, {history_prompt}, {query} 表示的是占位符，**用户不需要在自定义的 prompt template 中将这些值填充上**，gbi 系统会自动根据 `NL2Sql.__call__` 函数中提交的参数 或者 Nl2sql 的成员变量 进行填充这些占位符，从而产生最后的给大模型的 prompt。注意 prompt template 和 prompt 的区别。\n",
    "\n",
    "* prompt template - 是带有占位符的 prompt, gbi 会根据具体参数填充到这些占位符，形成最终的 prompt\n",
    "* promt - 是将 prompt template 填充完占位符的结果。\n",
    "\n",
    "用户可以使用这些占位符重新设置自己的 prompt 模版，从而达到修改 prompt 的目的。\n",
    "具体请参考下面的 prompt template 示例:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "323fbe75-62ca-44ab-9ca2-9f747939a2b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "NL2SQL_PROMPT_TEMPLATE = \"\"\"\n",
    "  MySql 表 Schema 如下:\n",
    "  {schema}\n",
    "  请根据用户当前问题，联系历史信息，仅编写1个sql，其中 sql 语句需要使用```sql ```这种 markdown 形式给出。\n",
    "  请参考列选信息：\n",
    "  {instrument}\n",
    "  请参考知识:\n",
    "  {kg}\n",
    "  当前时间：{date}\n",
    "  历史信息如下:\n",
    "  {history_prompt}\n",
    "  当前问题：\"{query}\"\n",
    "  回答：\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "52436f03-e01c-456a-aaa0-5a7f1afcd9d2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sql: SELECT * FROM supper_market_info WHERE 商品类别 = '水果'\n",
      "-----------------\n",
      "llm result: ```sql\n",
      "SELECT * FROM supper_market_info WHERE 商品类别 = '水果'\n",
      "```\n"
     ]
    }
   ],
   "source": [
    "\n",
    "gbi_nl2sql5 = appbuilder.NL2Sql(model_name=\"ERNIE-Bot 4.0\", table_schemas=table_schemas, prompt_template=NL2SQL_PROMPT_TEMPLATE)\n",
    "nl2sql_result_message5 = gbi_nl2sql5(Message({\"query\": \"查看商品类别是水果的所有数据\"}))\n",
    "print(f\"sql: {nl2sql_result_message5.content.sql}\")\n",
    "print(\"-----------------\")\n",
    "print(f\"llm result: {nl2sql_result_message5.content.llm_result}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
